# Label for the tests in this file.
# NOTE(review): context() is superseded under testthat's 3rd edition (the file
# name is used as the label instead); drop this line if the package opts in.
context("tfdatasets")


test_succeeds("Use tfdatasets to train a keras model", {

  model <- keras_model_sequential(input_shape = 1) %>%
    layer_dense(units = 1)
  model %>% compile(loss = "mse", optimizer = "sgd")

  # 100 copies of the pair (x = 1, y = 1), shuffled and batched by 10. Each
  # element is an (x, y) tuple, so fit()/evaluate() are given no separate y.
  dataset <-
    tfdatasets::tensors_dataset(reticulate::tuple(list(1), list(1))) %>%
    tfdatasets::dataset_repeat(100) %>%
    tfdatasets::dataset_shuffle(buffer_size = 100) %>%
    tfdatasets::dataset_batch(10)

  # Exercise each entry point that accepts a dataset as input.
  model %>% fit(dataset, epochs = 2)
  evaluate(model, dataset)
  preds <- predict(model, dataset)

  # Use the predictions rather than discarding them: a model trained on this
  # finite, constant data should not produce missing values.
  expect_false(anyNA(preds))

})

test_that("Error when specifying batch_size with tfdatasets", {
  skip_if_no_keras()
  # TODO: do tf.data datasets work w/ jax backend? torch backend?
  # if (!is_tensorflow_implementation())
  #   skip("Datasets need TensorFlow implementation.")

  model <- keras_model_sequential(input_shape = 1) %>%
    layer_dense(units = 1)
  model %>% compile(loss = "mse", optimizer = "sgd")

  # 100 copies of the pair (x = 1, y = 1), shuffled and already batched by 10.
  dataset <-
    tfdatasets::tensors_dataset(reticulate::tuple(list(1), list(1))) %>%
    tfdatasets::dataset_repeat(100) %>%
    tfdatasets::dataset_shuffle(buffer_size = 100) %>%
    tfdatasets::dataset_batch(10)

  # Control case: the same model and dataset must train cleanly when
  # batch_size is omitted. Without this, the unscoped expect_error() below
  # would also pass if fit() rejected the dataset for an unrelated reason
  # (e.g. a backend that cannot consume tf.data datasets -- see the TODO).
  model %>% fit(dataset, epochs = 1, verbose = 0)

  # The dataset already defines its own batching, so an explicit batch_size
  # is expected to be rejected.
  expect_error(
    model %>% fit(dataset, epochs = 2, batch_size = 5)
  )

})


test_succeeds("Works with tf$distribute", {

  # Create and compile the model inside the strategy's scope so its variables
  # are created under the distribution strategy. `model` is assigned in the
  # test's environment and used after the scope exits.
  strategy <- tensorflow::tf$distribute$MirroredStrategy()
  with(strategy$scope(), {
    model <- layer_dense(keras_model_sequential(input_shape = 1), units = 1)
    compile(model, loss = "mse", optimizer = "sgd")
  })

  # 100 copies of the pair (x = 1, y = 1), shuffled and batched by 10.
  source_ds <- tfdatasets::tensors_dataset(
    reticulate::tuple(list(1), list(1))
  )
  repeated_ds <- tfdatasets::dataset_repeat(source_ds, 100)
  shuffled_ds <- tfdatasets::dataset_shuffle(repeated_ds, buffer_size = 100)
  dataset <- tfdatasets::dataset_batch(shuffled_ds, 10)

  # This test logs very verbosely, and there is no clean way to silence it:
  # TF_CPP_MIN_LOG_LEVEL only takes effect if set before TensorFlow is
  # initialized, so temporarily setting it around fit() (e.g. via
  # Sys.setenv() + on.exit() restore) does not help here.
  # https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
  # https://github.com/tensorflow/tensorflow/issues/45157
  fit(model, dataset, epochs = 10, verbose = 0)

})
